import ast
from collections import defaultdict
import pathlib
import queue
import threading
from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np
import scipy.stats as stats
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch_geometric.data as gd
from rdkit.Chem import Descriptors, QED
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor
from torch.distributions.dirichlet import Dirichlet
from torch.utils.data import Dataset

from gflownet.models import proxy_model
from gflownet.tasks.seh_frag import SEHFragTrainer
from gflownet.train import FlatRewards, GFNTask, RewardScalar
from gflownet.utils import metrics, sascore
from gflownet.utils.transforms import thermometer


class SEHMOOTask(GFNTask):
    """Sets up a multiobjective task where the rewards are (functions of):
    - the binding energy of a molecule to Soluble Epoxide Hydrolases.
    - its QED
    - its synthetic accessibility
    - its molecular weight

    The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`.

    Note: `self.rng` is read by `sample_conditional_information` but is never assigned in this class; it has to
    be set on the instance (by whatever drives sampling) before that method is called.
    """
    def __init__(self, dataset: Dataset, temperature_distribution: str, temperature_parameters: Tuple[float],
                 wrap_model: Callable[[nn.Module], nn.Module] = None):
        """Load the proxy model and store the temperature-sampling configuration.

        Args:
            dataset: stored as `self.dataset`; not otherwise used in this class.
            temperature_distribution: one of 'gamma', 'uniform' or 'beta'.
            temperature_parameters: parameters of that distribution (unpacked in
                `sample_conditional_information`).
            wrap_model: called on the loaded proxy; `_load_task_models` unpacks its result as a
                `(model, device)` pair. Despite the `None` default it is required in practice.
        """
        self._wrap_model = wrap_model
        self.models = self._load_task_models()
        self.dataset = dataset
        self.temperature_sample_dist = temperature_distribution
        self.temperature_dist_params = temperature_parameters
        # When set (see the 'seeded_single' preference type in the trainer), every sample uses this preference.
        self.seeded_preference = None

    def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
        """Wrap `y` as a tensor-backed `FlatRewards` without rescaling it."""
        return FlatRewards(torch.as_tensor(y))

    def inverse_flat_reward_transform(self, rp):
        """Inverse of `flat_reward_transform`; the identity since no rescaling is applied."""
        return rp

    def _load_task_models(self):
        """Load the pretrained sEH proxy, wrap it, and record the device returned by the wrapper."""
        # The model module is imported in this file under the name `proxy_model`; the `bengio2021flow` name
        # used here previously was never imported and raised a NameError as soon as the task was built.
        # NOTE(review): assumes `proxy_model` exposes `load_original_model` alongside `mol2graph` -- confirm.
        model = proxy_model.load_original_model()
        model, self.device = self._wrap_model(model)
        return {'seh': model}

    def sample_conditional_information(self, n):
        """Sample `n` (temperature, preference) conditioning vectors.

        Returns:
            A dict with 'beta' (the sampled temperatures), 'preferences' (one 4-dimensional preference per
            sample) and 'encoding' (thermometer encoding of beta concatenated with the preferences along
            dimension 1).

        Raises:
            ValueError: if `self.temperature_sample_dist` is not 'gamma', 'uniform' or 'beta'.
        """
        dist = self.temperature_sample_dist
        if dist == 'gamma':
            loc, scale = self.temperature_dist_params
            beta = self.rng.gamma(loc, scale, n).astype(np.float32)
            # The gamma distribution is unbounded, so the encoding range is capped at its 95th percentile.
            upper_bound = stats.gamma.ppf(0.95, loc, scale=scale)
        elif dist == 'uniform':
            beta = self.rng.uniform(*self.temperature_dist_params, n).astype(np.float32)
            upper_bound = self.temperature_dist_params[1]
        elif dist == 'beta':
            beta = self.rng.beta(*self.temperature_dist_params, n).astype(np.float32)
            upper_bound = 1
        else:
            # Without this branch `beta` stayed None and `upper_bound` was unbound, which surfaced as an
            # unrelated-looking error from `torch.tensor(None)`.
            raise ValueError(
                f"Unknown temperature_sample_dist {dist!r}; expected 'gamma', 'uniform' or 'beta'")
        beta_enc = thermometer(torch.tensor(beta), 32, 0, upper_bound)  # TODO: hyperparameters
        if self.seeded_preference is not None:
            preferences = torch.tensor([self.seeded_preference] * n).float()
        else:
            # Draw a Dirichlet concentration per sample: a uniform-Dirichlet direction `a` scaled by an
            # exponential magnitude `b`, then sample the actual preference from Dirichlet(a * b).
            a = np.random.dirichlet([1] * 4, n)
            b = np.random.exponential(1, n)[:, None]
            preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float()
        encoding = torch.cat([beta_enc, preferences], 1)
        return {'beta': torch.tensor(beta), 'encoding': encoding, 'preferences': preferences}

    def encode_conditional_information(self, info):
        """Build conditioning from given preferences `info`, using a constant beta.

        Beta is set to the last temperature parameter and its encoding to all ones (32 columns).
        """
        # This assumes we're using a constant (max) beta and that info is the preferences
        encoding = torch.cat([torch.ones((len(info), 32)), info], 1)
        return {'beta': torch.ones(len(info)) * self.temperature_dist_params[-1],
                'encoding': encoding.float(),
                'preferences': info.float()}

    def cond_info_to_reward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
        """Scalarize flat rewards: preference-weighted sum over objectives, raised to the power beta."""
        if isinstance(flat_reward, list):
            if isinstance(flat_reward[0], Tensor):
                flat_reward = torch.stack(flat_reward)
            else:
                flat_reward = torch.tensor(flat_reward)
        scalar_reward = (flat_reward * cond_info['preferences']).sum(1)
        return scalar_reward**cond_info['beta']

    def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
        """Compute the four objectives (sEH proxy, QED, SA, molecular weight) for each valid molecule.

        Returns:
            `(flat_rewards, is_valid)`: `is_valid` is a boolean mask over `mols` (True where `mol2graph`
            produced a graph) and `flat_rewards` has one row of 4 values per valid molecule only.
        """
        graphs = [proxy_model.mol2graph(i) for i in mols]
        is_valid = torch.tensor([i is not None for i in graphs]).bool()
        if not is_valid.any():
            return FlatRewards(torch.zeros((0, 4))), is_valid
        batch = gd.Batch.from_data_list([i for i in graphs if i is not None])
        batch.to(self.device)
        seh_preds = self.models['seh'](batch).reshape((-1,)).clip(1e-4, 100).data.cpu() / 8
        seh_preds[seh_preds.isnan()] = 0

        def safe(f, x, default):
            # The RDKit-based scorers can raise on unusual molecules; fall back to a worst-case value.
            try:
                return f(x)
            except Exception:
                return default

        # Filter once instead of re-zipping `mols` with `is_valid` for every descriptor.
        valid_mols = [mol for mol, graph in zip(mols, graphs) if graph is not None]
        qeds = torch.tensor([safe(QED.qed, i, 0) for i in valid_mols])
        sas = torch.tensor([safe(sascore.calculateScore, i, 10) for i in valid_mols])
        sas = (10 - sas) / 9  # Turn into a [0-1] reward
        molwts = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i in valid_mols])
        molwts = ((300 - molwts) / 700 + 1).clip(0, 1)  # 1 until 300 then linear decay to 0 until 1000
        flat_rewards = torch.stack([seh_preds, qeds, sas, molwts], 1)
        return FlatRewards(flat_rewards), is_valid


class SEHMOOFragTrainer(SEHFragTrainer):
    """Fragment-based trainer for the multi-objective sEH task (`SEHMOOTask`)."""

    def default_hps(self) -> Dict[str, Any]:
        """Return the parent's default hyperparameters extended with the multi-objective ones."""
        return {
            **super().default_hps(),
            'use_fixed_weight': False,
            'num_cond_dim': 32 + 4,  # thermometer encoding of beta + 4 preferences
            'sampling_tau': 0.95,
            'valid_sample_cond_info': False,
            'preference_type': 'dirichlet',
        }

    def setup(self):
        """Build the task, the sampling hooks and the validation preference set.

        Raises:
            ValueError: if `hps['preference_type']` is not 'dirichlet', 'seeded_single' or 'seeded_many'.
        """
        super().setup()
        self.task = SEHMOOTask(self.training_data, self.hps['temperature_sample_dist'],
                               ast.literal_eval(self.hps['temperature_dist_params']), wrap_model=self._wrap_model_mp)
        self.sampling_hooks.append(MultiObjectiveStatsHook(256, self.hps['log_dir']))

        preference_type = self.hps['preference_type']
        if preference_type == 'dirichlet':
            valid_preferences = metrics.generate_simplex(4, 5)  # This yields 35 points of dimension 4
        elif preference_type in ('seeded_single', 'seeded_many'):
            # Both seeded modes draw the same 10 preferences, reproducibly derived from the run seed.
            seeded_prefs = np.random.default_rng(142857 + int(self.hps['seed'])).dirichlet([1] * 4, 10)
            if preference_type == 'seeded_single':
                valid_preferences = seeded_prefs[int(self.hps['single_pref_target_idx'])].reshape((1, 4))
                # Training then samples this single preference as well.
                self.task.seeded_preference = valid_preferences[0]
            else:
                valid_preferences = seeded_prefs
        else:
            # Previously an unknown value left `valid_preferences` unbound and failed with UnboundLocalError.
            raise ValueError(
                f"Unknown preference_type {preference_type!r}; "
                "expected 'dirichlet', 'seeded_single' or 'seeded_many'")
        self._top_k_hook = TopKHook(10, 128, len(valid_preferences))
        self.test_data = RepeatedPreferenceDataset(valid_preferences, 128)
        self.valid_sampling_hooks.append(self._top_k_hook)

        self.algo.task = self.task

    def build_callbacks(self):
        """Return the callbacks dict; 'topk' adds per-preference top-k rewards to the validation metrics."""
        try:
            from determined.pytorch import PyTorchCallback
        except ImportError:
            # Allows running without Determined installed.
            PyTorchCallback = object

        parent = self

        class TopKMetricCB(PyTorchCallback):
            def on_validation_end(self, metrics: Dict[str, Any]):
                top_k = parent._top_k_hook.finalize()
                for i, value in enumerate(top_k):
                    metrics[f'topk_rewards_{i}'] = value
                print('validation end', metrics)

        return {'topk': TopKMetricCB()}


class MultiObjectiveStatsHook:
    """Sampling hook that tracks Pareto-front statistics of the flat rewards it is called with.

    Each call computes the Pareto front over the last `num_to_keep` samples, reports hypervolume metrics, and
    pushes that local front to a queue. A background thread merges the queued fronts into a lifetime Pareto
    front and periodically saves it to `<log_dir>/pareto.pt`.
    """
    def __init__(self, num_to_keep: int, log_dir: str, save_every=50):
        """
        Args:
            num_to_keep: size of the sliding window of recent samples used for the per-call statistics.
            log_dir: directory in which `pareto.pt` is written.
            save_every: the lifetime front is saved once every `save_every` merged updates.
        """
        # This __init__ is only called in the main process. This object is then (potentially) cloned
        # in pytorch data worker processes and __call__'ed from within those processes. This means
        # each process will compute its own Pareto front, which we will accumulate in the main
        # process by pushing local fronts to self.pareto_queue.
        self.num_to_keep = num_to_keep
        self.all_flat_rewards: List[Tensor] = []
        self.all_smi: List[str] = []
        self.hsri_epsilon = 0.3
        self.compute_hsri = False
        self.compute_normed = False
        self.pareto_queue: mp.Queue = mp.Queue()
        self.pareto_front = None
        self.pareto_front_smi = None
        # Shared array so that the value written by the accumulation thread is readable from __call__.
        self.pareto_metrics = mp.Array('f', 4)
        self.stop = threading.Event()
        self.save_every = save_every
        self.log_path = pathlib.Path(log_dir) / 'pareto.pt'
        self.pareto_thread = threading.Thread(target=self._run_pareto_accumulation, daemon=True)
        self.pareto_thread.start()

    def __del__(self):
        # Ask the accumulation thread to exit its polling loop.
        self.stop.set()

    def _run_pareto_accumulation(self):
        """Thread body: merge queued local fronts into the lifetime Pareto front until stopped."""
        num_updates = 0
        while not self.stop.is_set():
            try:
                r, smi = self.pareto_queue.get(True, 1)  # Block for a second then check if we've stopped
            except queue.Empty:
                continue
            except ConnectionError:
                break
            if self.pareto_front is None:
                p = self.pareto_front = r
                psmi = smi
            else:
                p = np.concatenate([self.pareto_front, r], 0)
                psmi = self.pareto_front_smi + smi
            # Rewards are negated before the efficiency test (same convention as in __call__).
            idcs = metrics.is_pareto_efficient(-p, False)
            self.pareto_front = p[idcs]
            self.pareto_front_smi = [psmi[i] for i in idcs]
            self.pareto_metrics[0] = metrics.get_hypervolume(torch.tensor(self.pareto_front), zero_ref=True)
            num_updates += 1
            if num_updates % self.save_every == 0:
                state = {
                    'pareto_front': self.pareto_front,
                    'pareto_metrics': list(self.pareto_metrics),
                    'pareto_front_smi': self.pareto_front_smi,
                }
                # Use a context manager so the file handle is closed after every periodic save
                # (it was previously opened inline and never closed).
                with open(self.log_path, 'wb') as f:
                    torch.save(state, f)

    def __call__(self, trajs, rewards, flat_rewards, cond_info):
        """Update the sliding window with a new batch and return a dict of Pareto metrics.

        The returned dict always contains the unnormalized hypervolumes and 'lifetime_hv0'; normalized
        hypervolumes and HSR indicators are added when `compute_normed` / `compute_hsri` are set.
        """
        self.all_flat_rewards = self.all_flat_rewards + list(flat_rewards)
        self.all_smi = self.all_smi + [i.get('smi', None) for i in trajs]
        if len(self.all_flat_rewards) > self.num_to_keep:
            self.all_flat_rewards = self.all_flat_rewards[-self.num_to_keep:]
            self.all_smi = self.all_smi[-self.num_to_keep:]

        flat_rewards = torch.stack(self.all_flat_rewards).numpy()
        pareto_idces = metrics.is_pareto_efficient(-flat_rewards, return_mask=False)
        gfn_pareto = flat_rewards[pareto_idces]
        pareto_smi = [self.all_smi[i] for i in pareto_idces]

        # Hand the local front to the accumulation thread of the main process.
        self.pareto_queue.put((gfn_pareto, pareto_smi))
        unnorm_hypervolume_with_zero_ref = metrics.get_hypervolume(torch.tensor(gfn_pareto), zero_ref=True)
        unnorm_hypervolume_wo_zero_ref = metrics.get_hypervolume(torch.tensor(gfn_pareto), zero_ref=False)
        info = {
            'UHV with zero ref': unnorm_hypervolume_with_zero_ref,
            'UHV w/o zero ref': unnorm_hypervolume_wo_zero_ref,
        }
        if self.compute_normed:
            # The normalizer is only needed here, so it is only built when normalized metrics are requested.
            target_min = flat_rewards.min(0).copy()
            target_range = flat_rewards.max(0).copy() - target_min
            hypercube_transform = metrics.Normalizer(
                loc=target_min,
                scale=target_range,
            )
            normed_gfn_pareto = hypercube_transform(gfn_pareto)
            hypervolume_with_zero_ref = metrics.get_hypervolume(torch.tensor(normed_gfn_pareto), zero_ref=True)
            hypervolume_wo_zero_ref = metrics.get_hypervolume(torch.tensor(normed_gfn_pareto), zero_ref=False)
            info = {
                **info,
                'HV with zero ref': hypervolume_with_zero_ref,
                'HV w/o zero ref': hypervolume_wo_zero_ref,
            }
        if self.compute_hsri:
            upper = np.zeros(gfn_pareto.shape[-1]) + self.hsri_epsilon
            lower = np.ones(gfn_pareto.shape[-1]) * -1 - self.hsri_epsilon
            hsr_indicator = metrics.HSR_Calculator(lower, upper)
            # The HSR computation is best-effort: on failure the indicator is reported as 0.
            try:
                hsri_w_pareto, _ = hsr_indicator.calculate_hsr(-1 * gfn_pareto)
            except Exception:
                hsri_w_pareto = 0
            try:
                hsri_on_flat, _ = hsr_indicator.calculate_hsr(-1 * flat_rewards)
            except Exception:
                hsri_on_flat = 0
            info = {
                **info,
                'hsri_with_pareto': hsri_w_pareto,
                'hsri_on_flat_rew': hsri_on_flat,
            }
        # Hypervolume of the lifetime front, as last written by the accumulation thread.
        info['lifetime_hv0'] = self.pareto_metrics[0]

        return info


class TopKHook:
    """Validation hook that collects rewards and reports the mean top-k reward per preference.

    Samples are grouped by `data_idx // repeats`, i.e. each block of `repeats` consecutive dataset indices
    forms one preference group.
    """
    def __init__(self, k, repeats, num_preferences):
        """
        Args:
            k: number of highest rewards averaged per preference.
            repeats: number of consecutive dataset indices that share one preference.
            num_preferences: number of preference groups expected in `finalize`.
        """
        self.queue: mp.Queue = mp.Queue()
        self.k = k
        self.repeats = repeats
        self.num_preferences = num_preferences

    def __call__(self, trajs, rewards, flat_rewards, cond_info):
        """Queue `(data_idx, reward)` pairs for this batch; contributes no metrics of its own."""
        self.queue.put([(i['data_idx'], r) for i, r in zip(trajs, rewards)])
        return {}

    def finalize(self):
        """Drain the queue and return the mean of the top-k rewards for each preference.

        Returns:
            A list whose i-th entry belongs to the i-th smallest preference index seen.

        Raises:
            AssertionError: if the number of preference groups seen differs from `num_preferences`.
        """
        data = []
        while not self.queue.empty():
            try:
                data += self.queue.get(True, 1)
            except queue.Empty:
                print("Warning, TopKHook queue timed out!")
                break
        grouped = defaultdict(list)
        for idx, r in data:
            grouped[idx // self.repeats].append(r)
        # Batches are queued in arbitrary order, so iterate in sorted preference order: otherwise entry i
        # would follow first-arrival order and could be reported under the wrong preference index.
        top_ks = [np.mean(sorted(grouped[pref_idx])[-self.k:]) for pref_idx in sorted(grouped)]
        assert len(top_ks) == self.num_preferences, (
            f"expected rewards for {self.num_preferences} preferences, got {len(top_ks)}")
        return top_ks


class RepeatedPreferenceDataset:
    """Map-style dataset that yields each preference `repeat` times in a row.

    Index `i` maps to `preferences[i // repeat]`, returned as a tensor.
    """
    def __init__(self, preferences, repeat):
        self.prefs = preferences
        self.repeat = repeat

    def __len__(self):
        return len(self.prefs) * self.repeat

    def __getitem__(self, idx):
        """Return the preference for position `idx`.

        Raises:
            IndexError: if `idx` is outside `[0, len(self))`. Negative indices are not supported.
        """
        # IndexError (rather than an assert) keeps the check under `python -O` and lets plain
        # iteration over the dataset terminate normally at the end.
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} is out of range for dataset of length {len(self)}")
        return torch.tensor(self.prefs[int(idx // self.repeat)])


def main():
    """Example of how this model can be run outside of Determined"""
    # Hyperparameters for a standalone run; anything not listed here comes from the trainer's defaults.
    hps = dict(
        lr_decay=10000,
        log_dir='/scratch/logs/seh_frag_moo/',
        num_training_steps=20_000,
        validate_every=5,
        sampling_tau=0.95,
        num_layers=6,
        num_data_loader_workers=12,
        global_batch_size=256,
        temperature_dist_params='(1, 2)',
        algo='TB',
        sql_alpha=0.01,
        seed=0,
        preference_type='dirichlet',
    )
    device = torch.device('cuda')
    trial = SEHMOOFragTrainer(hps, device)
    trial.verbose = True
    trial.run()


# Only launch the example run when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
